import csv
import sys
import math
import ols
from numpy import *
from numpy.random import randn
# Module-level state filled in by load_data and read by cond_prob.

# Maps each genotype key to the total number of data rows seen with it.
totaldic = dict()

# One dict per phenotype index; each maps a genotype key to the number of
# data rows seen with that genotype and that phenotype.
datalist = list()

def cond_prob(phenotype, genotype):
    """Return the empirical conditional probability P(phenotype | genotype).

    Uses the module-level counts filled in by ``load_data``. The result is the
    number of rows with this genotype and phenotype, divided by the total
    number of rows with this genotype.

    phenotype -- integer index into ``datalist``
    genotype  -- genotype key string (the concatenated genotype columns)

    Returns 0.0 when the genotype was never observed with this phenotype.
    """
    counts = datalist[phenotype]
    # ``in`` replaces dict.has_key, which is deprecated and gone in Python 3.
    if genotype not in counts:
        return 0.0
    return float(counts[genotype]) / totaldic[genotype]

def load_data(filename, numphenotypes):
    """Load genotype/phenotype counts from a CSV file.

    The first row is treated as a header and skipped. In every other row, all
    columns except the last are concatenated into a genotype key string. The
    last column is parsed with ``int`` as the phenotype index. Blank lines are
    ignored.

    The counts are stored in the module-level ``datalist`` and ``totaldic``
    (which ``cond_prob`` reads). Any previously loaded counts are discarded
    first.

    filename      -- path to the CSV file
    numphenotypes -- number of per-phenotype dicts to create; a phenotype
                     value outside ``range(numphenotypes)`` raises IndexError

    Returns ``[datalist, totaldic]``.
    """
    # Clear in place rather than rebinding: rebinding here would only create
    # locals and leave the globals that cond_prob reads untouched.
    del datalist[:]
    totaldic.clear()
    datalist.extend({} for _ in range(numphenotypes))

    with _open_csv(filename) as handle:
        rows = csv.reader(handle)
        next(rows, None)  # skip the header row
        for row in rows:
            if not row:
                # csv yields [] for blank lines; previously such a line was
                # counted as an empty genotype.
                continue
            # Join every column but the last, so a multi-character phenotype
            # value cannot leak into the genotype key.
            genotype = "".join(row[:-1])
            phenotype_counts = datalist[int(row[-1])]
            phenotype_counts[genotype] = phenotype_counts.get(genotype, 0) + 1
            totaldic[genotype] = totaldic.get(genotype, 0) + 1
    return [datalist, totaldic]


def _open_csv(filename):
    """Open *filename* in the mode the csv module expects on this Python."""
    if sys.version_info[0] < 3:
        return open(filename, "rb")
    return open(filename, "r", newline="")

def str_to_arr(phenotype):
    """Convert a string of single-character codes into a float array.

    Each character is parsed with ``int`` and becomes one element, e.g.
    ``"012"`` -> ``array([0., 1., 2.])``. An empty string gives an empty
    array. Despite the parameter name, the script passes genotype key strings
    here.

    Raises ValueError if a character cannot be parsed by ``int``.
    """
    # Build all values in one pass: numpy.append inside a loop copies the
    # whole array on every call, which is quadratic.
    return array([int(ch) for ch in phenotype], dtype=float)
                               
    

#def logistic_reg(filename, numphenotypes):
    
          
        
# Command-line entry point: regress logit(P(phenotype 0 | genotype)) on the
# genotype digits with ordinary least squares.
if len(sys.argv) < 2:
    # sys.argv[0] is the script itself, so a file argument needs len >= 2.
    print("please specify a file name")
    sys.exit(1)

# read in the filename from the command line and load the counts from it
filename = sys.argv[1]
datalist, totaldic = load_data(filename, 2)

# One regression row per genotype observed with phenotype 0.
xrows = []
yvals = []
for combos in datalist[0]:
    prob = cond_prob(0, combos)
    if prob <= 0.0 or prob >= 1.0:
        # The logit is infinite at 0 and 1 (math.log raises ValueError),
        # so such a genotype cannot be used as an OLS data point.
        sys.stderr.write("skipping genotype %r: logit undefined for p=%r\n"
                         % (combos, prob))
        continue
    xrows.append(str_to_arr(combos))
    yvals.append(math.log(prob) - math.log(1 - prob))

if not yvals:
    sys.stderr.write("no genotypes with a finite logit; nothing to fit\n")
    sys.exit(1)

xarr = array(xrows)
yarr = array(yvals)
print(yarr)
print(xarr)
# One regressor name per genotype column: x1, x2, ...
mymodel = ols.ols(yarr, xarr, 'yarr',
                  ['x%d' % (i + 1) for i in range(xarr.shape[1])])
print(mymodel.p)

